% reference: Multi-instance Domain Adaptation for Vaccine Adverse Event Detection
% Junxiang Wang, Liang Zhao. 27th International World Wide Web Conference (WWW 2018).
% @inproceedings{Wang2018,
%    author    = "Junxiang Wang and Liang Zhao",
%    title     = "Multi-instance Domain Adaptation for Vaccine Adverse Event Detection",
%    booktitle = "Proceedings of the 27th International World Wide Web Conference",
%    year      = "2018",
%    month     = "apr"
%}
% contact Junxiang Wang (jwang40@gmu.edu)
% Put the bundled liblinear MATLAB interface (with all its subfolders) on the
% search path, start from an empty workspace, then load the dataset variables.
addpath(genpath('liblinear-2.11'));
clear
load('data.mat');
% data: n*k matrix, Twitter data (one row of keyword features per tweet).
% InstanceIndex: n*1 vector, maps each tweet to the user (bag) who posted it.
% label: m*1 vector, the label of each user: 1 means positive and 0 means negative.
% R: r*k matrix, the set of formal reports.
% where:
% n = number of tweets.
% m = number of users.
% k = number of keywords.
% r = number of formal reports.
% Hyperparameters handed to MIDA below (see MIDA for their exact roles;
% maxIter/maxCount are presumably iteration limits -- confirm in MIDA).
rho      = 10;
lambda1  = 0.01;
lambda2  = 10;
c        = 100;
maxIter  = 20;
maxCount = 10;
% Number of folds the users are split into.
nfold = 5;

% label holds one entry per user, so this is the number of users.
data_num = length(label);
% Reset the global random number generator so the fold assignment is reproducible.
rng('default');
% Randomly assign every user to one of nfold folds (crossvalind is part of the
% Bioinformatics Toolbox). Only one fold is used as the test set below, so
% this produces a single train/test split, not a full cross-validation loop.
indices = crossvalind('Kfold', data_num, nfold);
% Forming the training set and the test set.
% A single fold is held out as the test set (one train/test split, not a full
% cross-validation loop); change testFold to evaluate a different fold.
testFold = 1;
testIdx = find(indices == testFold);
trainIdx = find(indices ~= testFold);

% For each training user, collect the rows of that user's tweets, a bag id
% (the user's position j within trainIdx) for each of those rows, and the
% per-user keyword totals. The pieces are stored in cells and concatenated
% once after the loop; growing matrices inside the loop would re-copy the
% whole matrix on every iteration (quadratic cost).
numTrain = length(trainIdx);
trainRows = cell(numTrain, 1);
trainGroups = cell(numTrain, 1);
trainSums = cell(numTrain, 1);
for j = 1:numTrain
    index = find(InstanceIndex == trainIdx(j));
    % force a column so vertcat works even if InstanceIndex is a row vector
    trainRows{j} = index(:);
    trainGroups{j} = j * ones(length(index), 1);
    % bag-level feature vector: column sums of the user's tweet rows
    trainSums{j} = sum(data(index, :), 1);
end
% Rows keep the same order as before: grouped by user in trainIdx order,
% tweets within a user in their original order.
training = data(vertcat(trainRows{:}), :);
trainInstanceIndex = vertcat(trainGroups{:});
train_sum = vertcat(trainSums{:});

% Same grouping for the test users (no per-user sums are needed here).
numTest = length(testIdx);
testRows = cell(numTest, 1);
testGroups = cell(numTest, 1);
for j = 1:numTest
    index = find(InstanceIndex == testIdx(j));
    testRows{j} = index(:);
    testGroups{j} = j * ones(length(index), 1);
end
test = data(vertcat(testRows{:}), :);
testInstanceIndex = vertcat(testGroups{:});
% Initialization of beta using the liblinear library: fit an L1-regularized
% logistic regression ('-s 6') on the per-user keyword totals train_sum.
% '-B 1' appends a constant feature of value 1 to every instance, so liblinear
% learns an intercept whose weight is the last entry of model.w. Without -B,
% model.bias only echoes the -B option value (-1); it is not a learned
% intercept, and model.w then has no intercept column.
initModel = train(label(trainIdx), sparse(train_sum), '-s 6 -B 1 -q');
initW = initModel.w(:);
% For a two-class model, liblinear's w scores model.Label(1) as the positive
% side, and labels are ordered by their first appearance in the training
% labels. Flip the sign when that first label is not the positive class 1, so
% the initial weights do not depend on which user happens to come first.
if initModel.Label(1) ~= 1
    initW = -initW;
end
% beta: a (k+1)*1 vector, the first element is the intercept while the
% remaining is the weight of keywords.
beta = [initW(end); initW(1:end-1)];
% training the MIDA model
[~,~,beta,~,~] = MIDA(training,trainInstanceIndex,R,label(trainIdx),rho,lambda1,lambda2,c,maxIter,maxCount,beta);
% test the performance
[CM,auc,ROCX,ROCY,aupr,PRX,PRY] = MIDATEST(beta,test,testInstanceIndex,label(testIdx));
% Figure 1: ROC curve (true-positive rate against false-positive rate).
figure(1);
plot(ROCX, ROCY);
xlabel('FPR'); ylabel('TPR');

% Figure 2: precision-recall curve, drawn in red.
figure(2);
plot(PRX, PRY, 'r');
xlabel('Recall'); ylabel('Precision');